import pandas as pd
from tqdm import tqdm
import torch
from transformers import pipeline, AutoTokenizer
import warnings
from transformers import logging
logging.set_verbosity_error()
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# Access token passed to the Hugging Face pipeline below. Read it inside a
# context manager so the file handle is closed, and strip surrounding
# whitespace so a trailing newline in the key file does not become part of the
# token.
with open('./llama_key.txt', 'r') as key_file:
    llama_key = key_file.read().strip()
device = 'cuda'
# Data import and clean
motn = pd.read_csv('./data/freedom_test.csv', encoding="ISO-8859-1")
# Keep only rows that have a premise to classify. The non-inplace reset_index
# returns a fresh DataFrame with a 0..n-1 index, so the label column added at
# the end of the script is written to an independent frame rather than to a
# filtered slice of the raw data.
motn = motn[~motn['premise'].isna()].reset_index(drop=True)

# Build the text-generation pipeline around the Llama 3.1 8B Instruct
# checkpoint: weights are loaded in bfloat16, placed on `device`, and fetched
# using the access token read from the key file.
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
pipe = pipeline(
    task="text-generation",
    model=model_id,
    token=llama_key,
    device_map=device,
    model_kwargs=dict(torch_dtype=torch.bfloat16),
)

#############
## Label Docs
#############
# Prompt template used for every survey answer; the single `{}` placeholder is
# filled with the respondent's text via `user_message.format(...)`.
# Label convention requested from the model:
#   0 -> the answer IS about freedom and rights
#   1 -> the answer is NOT about freedom and rights
# NOTE(review): the sentence about voting looks like it is missing a comma (or
# "or") after "voting rights". It is left untouched here because the prompt
# wording is runtime text and changing it could change the model's answers.
user_message = """You are a classifier that determines if a survey answer is about freedom and rights. The survey answer I will show is a response to the question "What does democracy mean to you?" 

Answers about freedom and rights reference "freedom" or "liberty" in the abstract, or reference a particular freedom such as the freedom to express one's opinion, freedom of religion, or other rights mentioned in the Bill of Rights. They may also contain references to freedom from government, freedom of speech, the Second Amendment, references to civil rights or substantive rights, and references to capitalism or free enterprise. 
However, the answer is not about freedom and rights if the answer is about voting rights the "right" to vote, or anything having to do with voting.
I will show you an answer to the question. Read the answer and then determine if the answer is about freedom and rights or not. Here is the respondent's answer:
{}
If the answer is about freedom and rights, return 0. If it is not about freedom and rights, return 1. Do not explain your answer, and only return the number.
"""

# Classify every premise with one greedy generation per answer. `res` collects
# the raw pipeline result dicts (one per premise, in row order) for the
# label-extraction step that follows.
data = motn
res = []
for doc in data['premise']:
    messages = [
        {"role": "user", "content": user_message.format(doc)},
    ]
    # Render the single-turn chat as a plain string and append the assistant
    # header so the model starts its reply immediately.
    prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    outputs = pipe(
        prompt,
        # The expected reply is a single digit, so a couple of tokens suffice.
        max_new_tokens=2,
        # Greedy decoding keeps the labels deterministic across runs.
        do_sample=False,
        # Return only the completion, not the prompt echoed back.
        return_full_text=False,
        # The rendered chat template already contains the special tokens; the
        # pipeline tokenizes the string again, so disable special-token
        # insertion to avoid prepending a second beginning-of-text token.
        add_special_tokens=False,
        # Pad with EOS (presumably because the tokenizer defines no dedicated
        # pad token - confirm).
        pad_token_id=pipe.tokenizer.eos_token_id,
    )
    res.extend(outputs)

# Keep only the generated completion from each pipeline result dict.
res = [text['generated_text'] for text in res]

# Extract and add labels to df
# Use the first non-whitespace character of each completion. Stripping first
# means a reply such as " 1" or "\n1" is still read as 1, and slicing instead
# of indexing means an empty completion yields '' rather than an IndexError.
res = [text.strip()[:1] for text in res]
# Label is 1 only when the model answered "1" (not about freedom and rights);
# every other reply, including unparseable ones, is recorded as 0.
motn['llama'] = [1 if text == '1' else 0 for text in res]

# The file is read as ISO-8859-1 at the top of the script and overwritten
# here, so write it back in the same encoding. This keeps the text columns
# byte-compatible with the original file and lets the script re-read its own
# output. Note the overwrite also drops the rows that had no premise.
motn.to_csv('./data/freedom_test.csv', index=False, encoding="ISO-8859-1")